{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TwOwlnLuRZXu"
      },
      "outputs": [],
      "source": [
        "# Copyright 2021 The Google Research Authors.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JaCqYOye9DLX"
      },
      "source": [
        "# NOTE\n",
        "#### Make sure that this notebook is using the `smug` kernel.\n",
        "#### We use the Inception_v1 model from TF-Slim here. Our paper used a slightly different variant of this model without batch norm, so visualizations and results may differ. At the bottom, we also show examples for Inception_v3."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iuaZMP90GCRd"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "import saliency\n",
        "import tensorflow.compat.v1 as tf\n",
        "import tensorflow_hub as hub\n",
        "import tf_slim as slim\n",
        "tf.disable_eager_execution()\n",
        "\n",
        "# Fetch the TF-Slim model definitions and the Inception v1 checkpoint once.\n",
        "if not os.path.exists('models/research/slim'):\n",
        "  !git clone https://github.com/tensorflow/models/\n",
        "\n",
        "if not os.path.exists('inception_v1_2016_08_28.tar.gz'):\n",
        "  !wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz\n",
        "  !tar -xvzf inception_v1_2016_08_28.tar.gz\n",
        "\n",
        "# `nets` and `smug_saliency` are imported by temporarily changing the working\n",
        "# directory (presumably relying on the current directory being on sys.path).\n",
        "# try/finally guarantees the kernel is back in the notebook directory even if\n",
        "# an import fails, so re-running this cell always starts from the same place.\n",
        "old_cwd = os.getcwd()\n",
        "try:\n",
        "  os.chdir('models/research/slim')\n",
        "  from nets import inception_v1\n",
        "  os.chdir(old_cwd)\n",
        "\n",
        "  os.chdir('../')\n",
        "  from smug_saliency import masking\n",
        "  from smug_saliency import utils\n",
        "finally:\n",
        "  # Restore the original directory instead of assuming it is named\n",
        "  # 'smug_saliency/'.\n",
        "  os.chdir(old_cwd)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "10LcAJ_2VjmN"
      },
      "outputs": [],
      "source": [
        "# Description of the pre-trained Inception v1 network used by the masking\n",
        "# routines below.\n",
        "run_params_inception_v1 = masking.RunParams(\n",
        "    model_type='cnn',\n",
        "    # model_path is the path to a frozen tensorflow graph (usually a '.pb'\n",
        "    # file), which can be loaded with utils.restore_model. No frozen graph is\n",
        "    # available here, so model_path is left empty and a custom loader,\n",
        "    # restore_inception_v1 (below), is used instead.\n",
        "    model_path='',\n",
        "    image_placeholder_shape=(1, 224, 224, 3),\n",
        "    padding=(2, 3),\n",
        "    strides=2,\n",
        "    activations=None,\n",
        "    # Range of input pixel values expected by the model.\n",
        "    pixel_range=(0, 1),\n",
        "    # The tensor names can be found by printing the tf ops with\n",
        "    # restore_inception_v1(print_ops=True).\n",
        "    tensor_names={\n",
        "        'input': 'Placeholder:0',\n",
        "        'first_layer': 'InceptionV1/InceptionV1/Conv2d_1a_7x7/Conv2D:0',\n",
        "        'first_layer_relu': 'InceptionV1/InceptionV1/Conv2d_1a_7x7/Relu:0',\n",
        "        'logits': 'InceptionV1/Logits/SpatialSqueeze:0',\n",
        "        'softmax': 'InceptionV1/Logits/Predictions/Softmax:0',\n",
        "        'weights_layer_1': 'InceptionV1/InceptionV1/Conv2d_1a_7x7/Conv2D/ReadVariableOp:0',\n",
        "    },\n",
        ")\n",
        "\n",
        "def restore_inception_v1(model_path='./inception_v1.ckpt',\n",
        "                         print_ops=False):\n",
        "  \"\"\"Builds the TF-Slim Inception v1 graph and restores it from a checkpoint.\n",
        "\n",
        "  Args:\n",
        "    model_path: string, path to the Inception v1 checkpoint file.\n",
        "    print_ops: bool, prints operations in a tensorflow graph if true.\n",
        "\n",
        "  Returns:\n",
        "    session: tf.Session, tensorflow session with the loaded neural network.\n",
        "    graph: tensorflow graph corresponding to the tensorflow session.\n",
        "  \"\"\"\n",
        "  graph = tf.Graph()\n",
        "  with graph.as_default():\n",
        "    images = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))\n",
        "    with slim.arg_scope(inception_v1.inception_v1_arg_scope()):\n",
        "      # Only the ops added to `graph` matter here; tensors are looked up later\n",
        "      # by name, so the returned logits and end points are not kept.\n",
        "      inception_v1.inception_v1(images, is_training=False, num_classes=1001)\n",
        "\n",
        "    # Restore the checkpoint.\n",
        "    session = tf.Session(graph=graph)\n",
        "    try:\n",
        "      saver = tf.train.Saver()\n",
        "      saver.restore(session, model_path)\n",
        "    except Exception:\n",
        "      # Do not leak the session when the checkpoint cannot be restored.\n",
        "      session.close()\n",
        "      raise\n",
        "\n",
        "  # Find the appropriate tensornames by printing the tf ops.\n",
        "  # These tensornames are required to construct run_params.\n",
        "  if print_ops:\n",
        "    for op in graph.get_operations():\n",
        "      print(\"name:\", op.name)\n",
        "      print('inputs:')\n",
        "      for ip in op.inputs:\n",
        "        print(ip)\n",
        "      print('outputs:', op.outputs)\n",
        "      print('----\\n')\n",
        "  return session, graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-EKQE-hrFUmV"
      },
      "outputs": [],
      "source": [
        "# Print the name of the tensors so as to construct\n",
        "# run_params_inception_v1.tensor_names. The session is only a by-product\n",
        "# of building the graph here, so close it right away instead of leaking it.\n",
        "restore_inception_v1(print_ops=True)[0].close()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KCm2bFUG7mDr"
      },
      "source": [
        "#### Ensure that the first layer weights and biases are indeed correct"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nhtrhm1OFUmj"
      },
      "outputs": [],
      "source": [
        "def verify_first_layer_conv_weights(run_params, restore_model):\n",
        "  \"\"\"Checks run_params.tensor_names against a hand-computed convolution.\n",
        "\n",
        "  A random image is fed through the restored model, the first-layer\n",
        "  convolution is recomputed with utils.smt_convolution from the fetched\n",
        "  kernel weights (and biases, when 'biases_layer_1' is named), and the result\n",
        "  is compared with the model's own 'first_layer' output.\n",
        "\n",
        "  Args:\n",
        "    run_params: run parameters of the model; image_placeholder_shape,\n",
        "      pixel_range, tensor_names, padding and strides are read.\n",
        "    restore_model: callable taking no arguments and returning a\n",
        "      (session, graph) tuple.\n",
        "\n",
        "  Raises:\n",
        "    AssertionError: if the mean absolute difference between the two\n",
        "      convolution outputs exceeds 1e-6.\n",
        "  \"\"\"\n",
        "  image = utils.process_model_input(\n",
        "      np.random.random(run_params.image_placeholder_shape[1:]),\n",
        "      run_params.pixel_range)\n",
        "  session, _ = restore_model()\n",
        "  try:\n",
        "    # Fetching the whole tensor_names dict yields a dict with the same keys.\n",
        "    output_tensors = session.run(\n",
        "        run_params.tensor_names,\n",
        "        feed_dict={run_params.tensor_names['input']: [image]})\n",
        "  finally:\n",
        "    # Release the session even when one of the tensor names is invalid.\n",
        "    session.close()\n",
        "\n",
        "  if 'biases_layer_1' in run_params.tensor_names:\n",
        "    kernel_biases = output_tensors['biases_layer_1']\n",
        "  else:\n",
        "    # No bias tensor was named: use one zero bias per output channel.\n",
        "    kernel_biases = np.zeros(output_tensors['weights_layer_1'].shape[-1])\n",
        "\n",
        "  # Computes the convolution using nested for loops. The channel axis is\n",
        "  # moved to the front for utils.smt_convolution and moved back afterwards.\n",
        "  for_loop_convolution = utils.smt_convolution(\n",
        "      input_activation_maps=np.moveaxis(image, -1, 0),\n",
        "      kernels=output_tensors['weights_layer_1'],\n",
        "      kernel_biases=kernel_biases,\n",
        "      padding=run_params.padding,\n",
        "      strides=run_params.strides)\n",
        "  for_loop_convolution = np.moveaxis(\n",
        "      np.array(for_loop_convolution), 0, -1)\n",
        "\n",
        "  mean_abs_error = np.mean(\n",
        "      np.abs(output_tensors['first_layer'][0] - for_loop_convolution))\n",
        "  if mean_abs_error \u003e 1e-6:\n",
        "    message = 'The supplied names of the tensors is wrong.'\n",
        "    print(message)\n",
        "    # Raised explicitly: a bare `assert False` is removed under `python -O`.\n",
        "    raise AssertionError(message)\n",
        "  print('Tensor names in run_params is consistent.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mV60QtFXFUmj",
        "outputId": "96a0d8a8-91ad-461a-ece1-0f3702c50e66"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Tensor names in run_params is consistent.\n"
          ]
        }
      ],
      "source": [
        "# Sanity-check the Inception v1 tensor names before computing saliency maps.\n",
        "verify_first_layer_conv_weights(\n",
        "    run_params=run_params_inception_v1,\n",
        "    restore_model=restore_inception_v1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FqgvHJgRFUmk"
      },
      "outputs": [],
      "source": [
        "def _get_saliency_maps(image, run_params, restore_model,\n",
        "                       top_k=3000, window_size=3):\n",
        "  \"\"\"Computes the SMUG, SMUG_BASE and integrated-gradients saliency maps.\n",
        "\n",
        "  The model is restored afresh before each of the three computations, and\n",
        "  each session is closed as soon as its computation is done so that repeated\n",
        "  calls do not accumulate open sessions.\n",
        "\n",
        "  Args:\n",
        "    image: image array; it is passed through utils.process_model_input\n",
        "      together with run_params.pixel_range before being fed to the model.\n",
        "    run_params: run parameters of the model (tensor_names, pixel_range and\n",
        "      image_placeholder_shape are read here).\n",
        "    restore_model: callable taking no arguments and returning a\n",
        "      (session, graph) tuple.\n",
        "    top_k: forwarded to masking.get_no_minimization_mask and\n",
        "      masking.find_mask_first_layer.\n",
        "    window_size: forwarded to masking.find_mask_first_layer.\n",
        "\n",
        "  Returns:\n",
        "    Tuple (smug_saliency, no_minimization_mask, ig_saliency_map), where\n",
        "    smug_saliency is the first mask returned by masking.find_mask_first_layer\n",
        "    multiplied elementwise by no_minimization_mask.\n",
        "  \"\"\"\n",
        "  tf.reset_default_graph()\n",
        "  image = utils.process_model_input(image, run_params.pixel_range)\n",
        "\n",
        "  # NOTE(review): close() is expected to be harmless if a masking helper has\n",
        "  # already closed the session it was given -- confirm against masking.py.\n",
        "\n",
        "  # 1) Predicted label and the integrated-gradients baseline map.\n",
        "  restored_sess, restored_graph = restore_model()\n",
        "  try:\n",
        "    input_tensor = restored_graph.get_tensor_by_name(\n",
        "        run_params.tensor_names['input'])\n",
        "    label_index = np.argmax(restored_sess.run(\n",
        "        run_params.tensor_names['softmax'],\n",
        "        feed_dict={input_tensor: [image]}))\n",
        "    ig_saliency_map = saliency.core.VisualizeImageGrayscale(\n",
        "        masking.get_saliency_map(\n",
        "            session=restored_sess,\n",
        "            features=image,\n",
        "            saliency_method='integrated_gradients',\n",
        "            label=label_index,\n",
        "            input_tensor_name=run_params.tensor_names['input'],\n",
        "            output_tensor_name=run_params.tensor_names['softmax'],\n",
        "            graph=restored_graph))\n",
        "  finally:\n",
        "    restored_sess.close()\n",
        "\n",
        "  # 2) SMUG_BASE: the mask obtained without minimization.\n",
        "  restored_sess, restored_graph = restore_model()\n",
        "  try:\n",
        "    no_minimization_mask = utils.scale_saliency_map(\n",
        "        masking.get_no_minimization_mask(\n",
        "            image=image,\n",
        "            label_index=label_index,\n",
        "            run_params=run_params,\n",
        "            top_k=top_k,\n",
        "            session=restored_sess,\n",
        "            graph=restored_graph),\n",
        "        method='smug')\n",
        "  finally:\n",
        "    restored_sess.close()\n",
        "\n",
        "  # 3) SMUG: the mask found on the first layer.\n",
        "  restored_sess, restored_graph = restore_model()\n",
        "  try:\n",
        "    result = masking.find_mask_first_layer(\n",
        "        image=image,\n",
        "        label_index=label_index,\n",
        "        run_params=run_params,\n",
        "        window_size=window_size,\n",
        "        score_method='integrated_gradients',\n",
        "        top_k=top_k,\n",
        "        gamma=0.0,\n",
        "        timeout=3600,\n",
        "        session=restored_sess,\n",
        "        graph=restored_graph)\n",
        "  finally:\n",
        "    restored_sess.close()\n",
        "\n",
        "  # Reshape the first mask to the placeholder shape and keep the first\n",
        "  # channel of the single batch element.\n",
        "  smug_mask = result['masks'][0].reshape(\n",
        "      run_params.image_placeholder_shape)[0, :, :, 0]\n",
        "  return (smug_mask * no_minimization_mask, no_minimization_mask,\n",
        "          ig_saliency_map)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OyfzQUFyFUmk"
      },
      "outputs": [],
      "source": [
        "def _get_saliency_params(image, saliency_map, run_params, restore_model):\n",
        "  \"\"\"Scores a saliency map with utils.calculate_saliency_score.\n",
        "\n",
        "  Args:\n",
        "    image: image the saliency map was computed for.\n",
        "    saliency_map: saliency map to be scored.\n",
        "    run_params: run parameters of the model.\n",
        "    restore_model: callable taking no arguments and returning a\n",
        "      (session, graph) tuple.\n",
        "\n",
        "  Returns:\n",
        "    Tuple (saliency_score, crop_mask) taken from the dict returned by\n",
        "    utils.calculate_saliency_score, or (None, None) when it returns None.\n",
        "  \"\"\"\n",
        "  tf.reset_default_graph()\n",
        "  session, _ = restore_model()\n",
        "  try:\n",
        "    saliency_score = utils.calculate_saliency_score(\n",
        "        run_params=run_params,\n",
        "        image=image,\n",
        "        saliency_map=saliency_map,\n",
        "        session=session)\n",
        "  finally:\n",
        "    # The session was created only for this scoring call; do not leak it.\n",
        "    session.close()\n",
        "  if saliency_score is None:\n",
        "    return None, None\n",
        "  return (saliency_score['saliency_score'],\n",
        "          saliency_score['crop_mask'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aJL5fQycFUmk"
      },
      "outputs": [],
      "source": [
        "def plot_saliency_maps(image, run_params, restore_model, window_size,\n",
        "                       show_bounding_box=False):\n",
        "  \"\"\"Plots the image next to its SMUG, SMUG_BASE and IG saliency maps.\n",
        "\n",
        "  Nothing is plotted if the saliency score of any of the three maps cannot\n",
        "  be computed (i.e. _get_saliency_params returns None for it).\n",
        "\n",
        "  Args:\n",
        "    image: image array to explain.\n",
        "    run_params: run parameters of the model.\n",
        "    restore_model: callable taking no arguments and returning a\n",
        "      (session, graph) tuple.\n",
        "    window_size: forwarded to _get_saliency_maps.\n",
        "    show_bounding_box: bool, if true the crop mask of every saliency map is\n",
        "      drawn with utils.show_bounding_box.\n",
        "  \"\"\"\n",
        "  smug_saliency, no_minimization_saliency, ig_saliency = _get_saliency_maps(\n",
        "      image=image, restore_model=restore_model, run_params=run_params,\n",
        "      window_size=window_size)\n",
        "\n",
        "  # (title template, saliency map) for subplots 2, 3 and 4, in order.\n",
        "  labelled_maps = (\n",
        "      ('SMUG score:{:.2f}', smug_saliency),\n",
        "      ('SMUG_BASE score:{:.2f}', no_minimization_saliency),\n",
        "      ('IG {:.2f}', ig_saliency),\n",
        "  )\n",
        "  panels = []\n",
        "  for title_template, saliency_map in labelled_maps:\n",
        "    score, crop_mask = _get_saliency_params(\n",
        "        image, saliency_map, run_params, restore_model)\n",
        "    if score is None:\n",
        "      # Formatting None with ':.2f' raises TypeError, so bail out before\n",
        "      # drawing anything. This now also covers the IG score.\n",
        "      return\n",
        "    panels.append((title_template.format(score), saliency_map, crop_mask))\n",
        "\n",
        "  fig = plt.figure(figsize=(10, 10))\n",
        "  fig.add_subplot(2, 2, 1)\n",
        "  plt.imshow(image)\n",
        "  plt.title('image')\n",
        "  utils.remove_ticks()\n",
        "\n",
        "  for position, (title, saliency_map, crop_mask) in enumerate(panels, start=2):\n",
        "    fig.add_subplot(2, 2, position)\n",
        "    plt.imshow(saliency_map, cmap='RdBu_r')\n",
        "    plt.title(title)\n",
        "    if show_bounding_box:\n",
        "      utils.show_bounding_box(crop_mask)\n",
        "    utils.remove_ticks()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jGje_qFhFUml"
      },
      "outputs": [],
      "source": [
        "# Load the example image; the context manager closes the underlying file\n",
        "# once the pixels have been copied into the array.\n",
        "with Image.open('tabby.jpg') as tabby_file:\n",
        "  image = np.array(tabby_file)\n",
        "# Paste the image into the top-left 224x224 corner of a white 299x299 canvas\n",
        "# (299x299 is the Inception v3 placeholder size used further below).\n",
        "# NOTE(review): assumes tabby.jpg is a 224x224 RGB image -- confirm.\n",
        "tabby = (255 * np.ones((299, 299, 3))).astype(int)\n",
        "tabby[:224, :224, :3] = image\n",
        "print(tabby.shape)\n",
        "plt.imshow(tabby)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p7zImFbMFUml"
      },
      "outputs": [],
      "source": [
        "# Inception v1 uses a 224x224 placeholder, so only the top-left crop of the\n",
        "# canvas is passed in.\n",
        "plot_saliency_maps(\n",
        "    image=tabby[:224, :224, :],\n",
        "    run_params=run_params_inception_v1,\n",
        "    restore_model=restore_inception_v1,\n",
        "    window_size=4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2qtVR1bC9ndp"
      },
      "source": [
        "### Inception v3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "42Momsw9FUmm",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "# Note that most of the IG attributions lie at the edge of the cat,\n",
        "# while SMUG and SMUG_BASE highlight the facial features of the cat.\n",
        "# This observation has been explained in sec. 5.2\n",
        "# of https://arxiv.org/pdf/2006.16322.pdf\n",
        "\n",
        "# Description of the Inception v3 TF-Hub module used by the masking routines.\n",
        "run_params_inception_v3 = masking.RunParams(\n",
        "    model_path='',\n",
        "    image_placeholder_shape=(1, 299, 299, 3),\n",
        "    model_type='cnn',\n",
        "    padding=(0, 0),\n",
        "    strides=2,\n",
        "    activations=None,\n",
        "    pixel_range=(-1, 1),\n",
        "    # The tensor names can be found by printing the tf ops with\n",
        "    # restore_inception_v3(print_ops=True).\n",
        "    tensor_names={\n",
        "        # Ideally the input tensor to inception v3 is\n",
        "        # 'module/hub_input/images:0'. That tensor takes pixel values in\n",
        "        # (0, 1), which the module rescales to (-1, 1) before feeding the\n",
        "        # rest of the network; the rescaled image is the tensor\n",
        "        # 'module/hub_input/Sub:0'. Because masking.find_mask_first_layer\n",
        "        # assumes that the convolution is performed directly on the input\n",
        "        # image without any rescaling, we feed the network through\n",
        "        # 'module/hub_input/Sub:0' instead.\n",
        "        'input': 'module/hub_input/Sub:0',\n",
        "        'first_layer': 'module/InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D:0',\n",
        "        'first_layer_relu': 'module/InceptionV3/InceptionV3/Conv2d_1a_3x3/Relu:0',\n",
        "        'logits': 'module/InceptionV3/Logits/SpatialSqueeze:0',\n",
        "        'softmax': 'module/InceptionV3/Predictions/Softmax:0',\n",
        "        'weights_layer_1': 'module/InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D/ReadVariableOp:0',\n",
        "    },\n",
        ")\n",
        "\n",
        "def restore_inception_v3(model_path=('https://tfhub.dev/google/imagenet/'\n",
        "                                     'inception_v3/classification/1'),\n",
        "                         print_ops=False):\n",
        "  \"\"\"Loads a TF-Hub module into a fresh graph and returns its session.\n",
        "\n",
        "  Args:\n",
        "    model_path: string, handle of the TF-Hub module to load (the default is\n",
        "      the tfhub.dev URL of the Inception v3 classification module).\n",
        "    print_ops: bool, prints operations in a tensorflow graph if true.\n",
        "\n",
        "  Returns:\n",
        "    session: tf.Session, tensorflow session with the loaded neural network.\n",
        "    graph: tensorflow graph corresponding to the tensorflow session.\n",
        "  \"\"\"\n",
        "  graph = tf.Graph()\n",
        "  session = tf.Session(graph=graph)\n",
        "  try:\n",
        "    with graph.as_default():\n",
        "      hub.Module(model_path)\n",
        "      session.run(tf.global_variables_initializer())\n",
        "      session.run(tf.tables_initializer())\n",
        "  except Exception:\n",
        "    # Do not leak the session when the module cannot be loaded or\n",
        "    # initialized.\n",
        "    session.close()\n",
        "    raise\n",
        "\n",
        "  # Find the appropriate tensornames by printing the tf ops.\n",
        "  # These tensornames are required to construct run_params.\n",
        "  if print_ops:\n",
        "    for op in graph.get_operations():\n",
        "      print(\"name:\", op.name)\n",
        "      print('inputs:')\n",
        "      for ip in op.inputs:\n",
        "        print(ip)\n",
        "      print('outputs:', op.outputs)\n",
        "      print('----\\n')\n",
        "  return session, graph\n",
        "\n",
        "# Print the graph ops to look up the tensor names; the session is only a\n",
        "# by-product here, so close it right away instead of leaking it.\n",
        "restore_inception_v3(print_ops=True)[0].close()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v7y3T_-hFUmm"
      },
      "outputs": [],
      "source": [
        "# Sanity-check the Inception v3 tensor names before computing saliency maps.\n",
        "verify_first_layer_conv_weights(\n",
        "    run_params=run_params_inception_v3,\n",
        "    restore_model=restore_inception_v3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Uz77-0w29uPF"
      },
      "outputs": [],
      "source": [
        "# Inception v3 uses a 299x299 placeholder, so the full canvas is passed in.\n",
        "plot_saliency_maps(\n",
        "    image=tabby,\n",
        "    run_params=run_params_inception_v3,\n",
        "    restore_model=restore_inception_v3,\n",
        "    window_size=3)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "image_saliency.ipynb",
      "provenance": [
        {
          "file_id": "1AhdqlDFsFrs3ctHx-Mz2N03KPEKx5pVW",
          "timestamp": 1616167425955
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
